{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 用 tf.data 加载 CSV 数据\n",
    "通过一个示例展示了怎样将 CSV 格式的数据加载进 tf.data.Dataset。\n",
    "\n",
    "使用的是泰坦尼克号乘客的数据。模型会根据乘客的年龄、性别、票务舱和是否独自旅行等特征来预测乘客生还的可能性。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\anaconda3\\envs\\tensorflow\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "# 设置\n",
    "import functools\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "import tensorflow_datasets as tfds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 训练&测试数据的url\n",
    "TRAIN_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/train.csv\"\n",
    "TEST_DATA_URL = \"https://storage.googleapis.com/tf-datasets/titanic/eval.csv\"\n",
    "\n",
    "train_file_path = tf.keras.utils.get_file(\"train.csv\", TRAIN_DATA_URL)\n",
    "test_file_path = tf.keras.utils.get_file(\"eval.csv\", TEST_DATA_URL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 让 numpy 数据更易读。\n",
    "# precision：控制输出结果的精度(即小数点后的位数)，默认值为8\n",
    "# suppress：小数是否需要以科学计数法的形式输出\n",
    "np.set_printoptions(precision=3, suppress=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 加载数据\n",
    "先打印CSV文件的前几行了解文件格式。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>survived</th>\n",
       "      <th>sex</th>\n",
       "      <th>age</th>\n",
       "      <th>n_siblings_spouses</th>\n",
       "      <th>parch</th>\n",
       "      <th>fare</th>\n",
       "      <th>class</th>\n",
       "      <th>deck</th>\n",
       "      <th>embark_town</th>\n",
       "      <th>alone</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>First</td>\n",
       "      <td>C</td>\n",
       "      <td>Cherbourg</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>First</td>\n",
       "      <td>C</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>male</td>\n",
       "      <td>28.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.4583</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Queenstown</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>622</th>\n",
       "      <td>0</td>\n",
       "      <td>male</td>\n",
       "      <td>28.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>10.5000</td>\n",
       "      <td>Second</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>623</th>\n",
       "      <td>0</td>\n",
       "      <td>male</td>\n",
       "      <td>25.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.0500</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>624</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>19.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>First</td>\n",
       "      <td>B</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>625</th>\n",
       "      <td>0</td>\n",
       "      <td>female</td>\n",
       "      <td>28.0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>23.4500</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Southampton</td>\n",
       "      <td>n</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>626</th>\n",
       "      <td>0</td>\n",
       "      <td>male</td>\n",
       "      <td>32.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7500</td>\n",
       "      <td>Third</td>\n",
       "      <td>unknown</td>\n",
       "      <td>Queenstown</td>\n",
       "      <td>y</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>627 rows × 10 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     survived     sex   age  n_siblings_spouses  parch     fare   class  \\\n",
       "0           0    male  22.0                   1      0   7.2500   Third   \n",
       "1           1  female  38.0                   1      0  71.2833   First   \n",
       "2           1  female  26.0                   0      0   7.9250   Third   \n",
       "3           1  female  35.0                   1      0  53.1000   First   \n",
       "4           0    male  28.0                   0      0   8.4583   Third   \n",
       "..        ...     ...   ...                 ...    ...      ...     ...   \n",
       "622         0    male  28.0                   0      0  10.5000  Second   \n",
       "623         0    male  25.0                   0      0   7.0500   Third   \n",
       "624         1  female  19.0                   0      0  30.0000   First   \n",
       "625         0  female  28.0                   1      2  23.4500   Third   \n",
       "626         0    male  32.0                   0      0   7.7500   Third   \n",
       "\n",
       "        deck  embark_town alone  \n",
       "0    unknown  Southampton     n  \n",
       "1          C    Cherbourg     n  \n",
       "2    unknown  Southampton     y  \n",
       "3          C  Southampton     n  \n",
       "4    unknown   Queenstown     y  \n",
       "..       ...          ...   ...  \n",
       "622  unknown  Southampton     y  \n",
       "623  unknown  Southampton     y  \n",
       "624        B  Southampton     y  \n",
       "625  unknown  Southampton     n  \n",
       "626  unknown   Queenstown     y  \n",
       "\n",
       "[627 rows x 10 columns]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pd.read_csv(train_file_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# 列名\n",
    "CSV_COLUMNS = ['survived', 'sex', 'age', 'n_siblings_spouses', 'parch', 'fare', 'class', 'deck', 'embark_town', 'alone']\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 预测的值的列是需要显式指定的\n",
    "LABEL_COLUMN = 'survived'\n",
    "LABELS = [0, 1]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 从文件中读取 CSV 数据并且创建 dataset\n",
    "def get_dataset(file_path):\n",
    "  dataset = tf.data.experimental.make_csv_dataset(\n",
    "      file_path,\n",
    "      batch_size=12, # 为了示例更容易展示，手动设置较小的值\n",
    "      label_name=LABEL_COLUMN,\n",
    "      na_value=\"?\",\n",
    "      num_epochs=1,\n",
    "      ignore_errors=True)\n",
    "  return dataset\n",
    "\n",
    "raw_train_data = get_dataset(train_file_path)\n",
    "raw_test_data = get_dataset(test_file_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "examples: \n",
      " OrderedDict([('sex', <tf.Tensor: shape=(12,), dtype=string, numpy=\n",
      "array([b'male', b'male', b'male', b'male', b'female', b'male', b'male',\n",
      "       b'female', b'male', b'female', b'male', b'female'], dtype=object)>), ('age', <tf.Tensor: shape=(12,), dtype=float32, numpy=\n",
      "array([ 4., 28., 47., 22., 18., 51., 27., 28., 28., 26., 31., 40.],\n",
      "      dtype=float32)>), ('n_siblings_spouses', <tf.Tensor: shape=(12,), dtype=int32, numpy=array([4, 0, 0, 1, 0, 0, 0, 1, 0, 1, 1, 0])>), ('parch', <tf.Tensor: shape=(12,), dtype=int32, numpy=array([2, 0, 0, 0, 2, 0, 0, 0, 0, 1, 0, 0])>), ('fare', <tf.Tensor: shape=(12,), dtype=float32, numpy=\n",
      "array([ 31.275,   7.896,  38.5  ,   7.25 ,  13.   ,   8.05 ,   8.663,\n",
      "        82.171,   0.   ,  26.   ,  52.   , 153.462], dtype=float32)>), ('class', <tf.Tensor: shape=(12,), dtype=string, numpy=\n",
      "array([b'Third', b'Third', b'First', b'Third', b'Second', b'Third',\n",
      "       b'Third', b'First', b'Second', b'Second', b'First', b'First'],\n",
      "      dtype=object)>), ('deck', <tf.Tensor: shape=(12,), dtype=string, numpy=\n",
      "array([b'unknown', b'unknown', b'E', b'unknown', b'unknown', b'unknown',\n",
      "       b'unknown', b'unknown', b'unknown', b'unknown', b'B', b'C'],\n",
      "      dtype=object)>), ('embark_town', <tf.Tensor: shape=(12,), dtype=string, numpy=\n",
      "array([b'Southampton', b'Cherbourg', b'Southampton', b'Southampton',\n",
      "       b'Southampton', b'Southampton', b'Southampton', b'Cherbourg',\n",
      "       b'Southampton', b'Southampton', b'Southampton', b'Southampton'],\n",
      "      dtype=object)>), ('alone', <tf.Tensor: shape=(12,), dtype=string, numpy=\n",
      "array([b'n', b'y', b'y', b'n', b'n', b'y', b'y', b'n', b'y', b'n', b'n',\n",
      "       b'y'], dtype=object)>)]) \n",
      "\n",
      "labels: \n",
      " tf.Tensor([0 0 0 0 1 0 1 1 0 0 0 1], shape=(12,), dtype=int32)\n"
     ]
    }
   ],
   "source": [
    "# dataset 中的每个条目都是一个批次，用一个元组（多个样本，多个标签）表示\n",
    "examples, labels = next(iter(raw_train_data)) # 第一个批次\n",
    "print(\"examples: \\n\", examples, \"\\n\")\n",
    "print(\"labels: \\n\", labels)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 数据预处理\n",
    "#### 【分类数据】\n",
    "CSV 数据中的有些列是分类的列。也就是说，这些列只能在有限的集合中取值。\n",
    "\n",
    "使用 tf.feature_column API 创建一个 tf.feature_column.indicator_column 集合，每个 tf.feature_column.indicator_column 对应一个分类的列。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "CATEGORIES = {\n",
    "    'sex': ['male', 'female'],\n",
    "    'class' : ['First', 'Second', 'Third'],\n",
    "    'deck' : ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'],\n",
    "    'embark_town' : ['Cherbourg', 'Southhampton', 'Queenstown'],\n",
    "    'alone' : ['y', 'n']\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "categorical_columns = []\n",
    "for feature, vocab in CATEGORIES.items():\n",
    "  cat_col = tf.feature_column.categorical_column_with_vocabulary_list(\n",
    "        key=feature, vocabulary_list=vocab)\n",
    "  categorical_columns.append(tf.feature_column.indicator_column(cat_col))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='sex', vocabulary_list=('male', 'female'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='class', vocabulary_list=('First', 'Second', 'Third'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='deck', vocabulary_list=('A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='embark_town', vocabulary_list=('Cherbourg', 'Southhampton', 'Queenstown'), dtype=tf.string, default_value=-1, num_oov_buckets=0)),\n",
       " IndicatorColumn(categorical_column=VocabularyListCategoricalColumn(key='alone', vocabulary_list=('y', 'n'), dtype=tf.string, default_value=-1, num_oov_buckets=0))]"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 刚才创建的内容\n",
    "categorical_columns"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 【连续数据】\n",
    "连续数据需要标准化。\n",
    "\n",
    "写一个函数标准化这些值，然后将这些值改造成 2 维的张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "def process_continuous_data(mean, data):\n",
    "  # 标准化数据\n",
    "  data = tf.cast(data, tf.float32) * 1/(2*mean)\n",
    "  return tf.reshape(data, [-1, 1])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "创建一个数值列的集合。tf.feature_columns.numeric_column API 会使用 normalizer_fn 参数。在传参的时候使用 functools.partial，functools.partial 由使用每个列的均值进行标准化的函数构成。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "MEANS = {\n",
    "    'age' : 29.631308,\n",
    "    'n_siblings_spouses' : 0.545455,\n",
    "    'parch' : 0.379585,\n",
    "    'fare' : 34.385399\n",
    "}\n",
    "\n",
    "# 数值列集合\n",
    "numerical_columns = []\n",
    "\n",
    "for feature in MEANS.keys():\n",
    "  num_col = tf.feature_column.numeric_column(feature, normalizer_fn=functools.partial(process_continuous_data, MEANS[feature]))\n",
    "  numerical_columns.append(num_col)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[NumericColumn(key='age', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x00000235DA383E50>, 29.631308)),\n",
       " NumericColumn(key='n_siblings_spouses', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x00000235DA383E50>, 0.545455)),\n",
       " NumericColumn(key='parch', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x00000235DA383E50>, 0.379585)),\n",
       " NumericColumn(key='fare', shape=(1,), default_value=None, dtype=tf.float32, normalizer_fn=functools.partial(<function process_continuous_data at 0x00000235DA383E50>, 34.385399))]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 刚才创建的数值列\n",
    "numerical_columns"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### 【创建预处理层】\n",
    "将这两个特征列的集合相加，并且传给 tf.keras.layers.DenseFeatures 从而创建一个进行预处理的输入层。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "preprocessing_layer = tf.keras.layers.DenseFeatures(categorical_columns+numerical_columns)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "\n",
    "### 构建模型\n",
    "从 preprocessing_layer （预处理层）开始构建 tf.keras.Sequential（层的线性叠加）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = tf.keras.Sequential([\n",
    "  preprocessing_layer, # 预处理层\n",
    "  tf.keras.layers.Dense(128, activation='relu'),\n",
    "  tf.keras.layers.Dense(128, activation='relu'),\n",
    "  tf.keras.layers.Dense(1, activation='sigmoid'),\n",
    "])\n",
    "\n",
    "model.compile(\n",
    "    loss='binary_crossentropy',\n",
    "    optimizer='adam',\n",
    "    metrics=['accuracy'])"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 训练、评估和预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_data = raw_train_data.shuffle(500) # shuffle把数组中的元素按随机顺序重新排列\n",
    "test_data = raw_test_data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/20\n",
      "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'collections.OrderedDict'> input: OrderedDict([('sex', <tf.Tensor 'ExpandDims_8:0' shape=(None, 1) dtype=string>), ('age', <tf.Tensor 'ExpandDims:0' shape=(None, 1) dtype=float32>), ('n_siblings_spouses', <tf.Tensor 'ExpandDims_6:0' shape=(None, 1) dtype=int32>), ('parch', <tf.Tensor 'ExpandDims_7:0' shape=(None, 1) dtype=int32>), ('fare', <tf.Tensor 'ExpandDims_5:0' shape=(None, 1) dtype=float32>), ('class', <tf.Tensor 'ExpandDims_2:0' shape=(None, 1) dtype=string>), ('deck', <tf.Tensor 'ExpandDims_3:0' shape=(None, 1) dtype=string>), ('embark_town', <tf.Tensor 'ExpandDims_4:0' shape=(None, 1) dtype=string>), ('alone', <tf.Tensor 'ExpandDims_1:0' shape=(None, 1) dtype=string>)])\n",
      "Consider rewriting this model with the Functional API.\n",
      "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'collections.OrderedDict'> input: OrderedDict([('sex', <tf.Tensor 'ExpandDims_8:0' shape=(None, 1) dtype=string>), ('age', <tf.Tensor 'ExpandDims:0' shape=(None, 1) dtype=float32>), ('n_siblings_spouses', <tf.Tensor 'ExpandDims_6:0' shape=(None, 1) dtype=int32>), ('parch', <tf.Tensor 'ExpandDims_7:0' shape=(None, 1) dtype=int32>), ('fare', <tf.Tensor 'ExpandDims_5:0' shape=(None, 1) dtype=float32>), ('class', <tf.Tensor 'ExpandDims_2:0' shape=(None, 1) dtype=string>), ('deck', <tf.Tensor 'ExpandDims_3:0' shape=(None, 1) dtype=string>), ('embark_town', <tf.Tensor 'ExpandDims_4:0' shape=(None, 1) dtype=string>), ('alone', <tf.Tensor 'ExpandDims_1:0' shape=(None, 1) dtype=string>)])\n",
      "Consider rewriting this model with the Functional API.\n",
      "53/53 [==============================] - 1s 1ms/step - loss: 0.5403 - accuracy: 0.7496\n",
      "Epoch 2/20\n",
      "53/53 [==============================] - 0s 924us/step - loss: 0.4344 - accuracy: 0.8022\n",
      "Epoch 3/20\n",
      "53/53 [==============================] - 0s 885us/step - loss: 0.4174 - accuracy: 0.8198\n",
      "Epoch 4/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.4069 - accuracy: 0.8214\n",
      "Epoch 5/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3980 - accuracy: 0.8341\n",
      "Epoch 6/20\n",
      "53/53 [==============================] - 0s 943us/step - loss: 0.3951 - accuracy: 0.8278\n",
      "Epoch 7/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3849 - accuracy: 0.8325\n",
      "Epoch 8/20\n",
      "53/53 [==============================] - 0s 924us/step - loss: 0.3808 - accuracy: 0.8453\n",
      "Epoch 9/20\n",
      "53/53 [==============================] - 0s 924us/step - loss: 0.3744 - accuracy: 0.8325\n",
      "Epoch 10/20\n",
      "53/53 [==============================] - 0s 885us/step - loss: 0.3701 - accuracy: 0.8421\n",
      "Epoch 11/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3677 - accuracy: 0.8437\n",
      "Epoch 12/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3558 - accuracy: 0.8485\n",
      "Epoch 13/20\n",
      "53/53 [==============================] - 0s 885us/step - loss: 0.3572 - accuracy: 0.8453\n",
      "Epoch 14/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3533 - accuracy: 0.8565\n",
      "Epoch 15/20\n",
      "53/53 [==============================] - 0s 924us/step - loss: 0.3525 - accuracy: 0.8549\n",
      "Epoch 16/20\n",
      "53/53 [==============================] - 0s 924us/step - loss: 0.3491 - accuracy: 0.8341\n",
      "Epoch 17/20\n",
      "53/53 [==============================] - 0s 943us/step - loss: 0.3420 - accuracy: 0.8517\n",
      "Epoch 18/20\n",
      "53/53 [==============================] - 0s 943us/step - loss: 0.3450 - accuracy: 0.8469\n",
      "Epoch 19/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3355 - accuracy: 0.8565\n",
      "Epoch 20/20\n",
      "53/53 [==============================] - 0s 905us/step - loss: 0.3360 - accuracy: 0.8596\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x235e1f88550>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model.fit(train_data, epochs=20)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'collections.OrderedDict'> input: OrderedDict([('sex', <tf.Tensor 'ExpandDims_8:0' shape=(None, 1) dtype=string>), ('age', <tf.Tensor 'ExpandDims:0' shape=(None, 1) dtype=float32>), ('n_siblings_spouses', <tf.Tensor 'ExpandDims_6:0' shape=(None, 1) dtype=int32>), ('parch', <tf.Tensor 'ExpandDims_7:0' shape=(None, 1) dtype=int32>), ('fare', <tf.Tensor 'ExpandDims_5:0' shape=(None, 1) dtype=float32>), ('class', <tf.Tensor 'ExpandDims_2:0' shape=(None, 1) dtype=string>), ('deck', <tf.Tensor 'ExpandDims_3:0' shape=(None, 1) dtype=string>), ('embark_town', <tf.Tensor 'ExpandDims_4:0' shape=(None, 1) dtype=string>), ('alone', <tf.Tensor 'ExpandDims_1:0' shape=(None, 1) dtype=string>)])\n",
      "Consider rewriting this model with the Functional API.\n",
      "22/22 [==============================] - 0s 1ms/step - loss: 0.4383 - accuracy: 0.8333\n",
      "\n",
      "\n",
      "Test Loss 0.43826165795326233, Test Accuracy 0.8333333134651184\n"
     ]
    }
   ],
   "source": [
    "# 测试评估\n",
    "test_loss, test_accuracy = model.evaluate(test_data)\n",
    "\n",
    "print('\\n\\nTest Loss {}, Test Accuracy {}'.format(test_loss, test_accuracy))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Layers in a Sequential model should only have a single input tensor, but we receive a <class 'collections.OrderedDict'> input: OrderedDict([('sex', <tf.Tensor 'ExpandDims_8:0' shape=(None, 1) dtype=string>), ('age', <tf.Tensor 'ExpandDims:0' shape=(None, 1) dtype=float32>), ('n_siblings_spouses', <tf.Tensor 'ExpandDims_6:0' shape=(None, 1) dtype=int32>), ('parch', <tf.Tensor 'ExpandDims_7:0' shape=(None, 1) dtype=int32>), ('fare', <tf.Tensor 'ExpandDims_5:0' shape=(None, 1) dtype=float32>), ('class', <tf.Tensor 'ExpandDims_2:0' shape=(None, 1) dtype=string>), ('deck', <tf.Tensor 'ExpandDims_3:0' shape=(None, 1) dtype=string>), ('embark_town', <tf.Tensor 'ExpandDims_4:0' shape=(None, 1) dtype=string>), ('alone', <tf.Tensor 'ExpandDims_1:0' shape=(None, 1) dtype=string>)])\n",
      "Consider rewriting this model with the Functional API.\n",
      "Predicted survival: 9.68%  | Actual outcome:  SURVIVED\n",
      "Predicted survival: 7.91%  | Actual outcome:  DIED\n",
      "Predicted survival: 0.36%  | Actual outcome:  DIED\n",
      "Predicted survival: 9.77%  | Actual outcome:  DIED\n",
      "Predicted survival: 75.22%  | Actual outcome:  DIED\n",
      "Predicted survival: 99.39%  | Actual outcome:  DIED\n",
      "Predicted survival: 62.61%  | Actual outcome:  DIED\n",
      "Predicted survival: 89.30%  | Actual outcome:  SURVIVED\n",
      "Predicted survival: 87.70%  | Actual outcome:  SURVIVED\n",
      "Predicted survival: 8.08%  | Actual outcome:  SURVIVED\n"
     ]
    }
   ],
   "source": [
    "# 预测\n",
    "predictions = model.predict(test_data)\n",
    "\n",
    "# 显示部分结果\n",
    "for prediction, survived in zip(predictions[:10], list(test_data)[0][1][:10]):\n",
    "  print(\"Predicted survival: {:.2%}\".format(prediction[0]),\n",
    "        \" | Actual outcome: \",\n",
    "        (\"SURVIVED\" if bool(survived) else \"DIED\"))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 参考\n",
    "https://www.tensorflow.org/tutorials/load_data/csv?hl=zh-cn"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tensorflow",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
